{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torch2trt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# DeepLabV3 segmentation network (ResNet-101 variant) with torchvision's pretrained weights.\n",
    "model = torchvision.models.segmentation.deeplabv3_resnet101(\n",
    "    pretrained=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the weights to the GPU, switch to inference mode, then cast to FP16.\n",
    "model = model.cuda()\n",
    "model = model.eval()\n",
    "model = model.half()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ModelWrapper(torch.nn.Module):\n",
    "    \"\"\"Adapter that exposes only the 'out' entry of the wrapped model's output.\n",
    "\n",
    "    The wrapped model's forward() result is indexed with ['out'], so callers of\n",
    "    this module receive that single value instead of the whole mapping.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model):\n",
    "        \"\"\"Register `model` as the `model` submodule.\"\"\"\n",
    "        super().__init__()\n",
    "        self.model = model\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Run the wrapped model on `x` and return its 'out' entry.\"\"\"\n",
    "        outputs = self.model(x)\n",
    "        return outputs['out']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the network so forward() yields only its 'out' entry, then cast to FP16.\n",
    "model_w = ModelWrapper(model)\n",
    "model_w = model_w.half()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example input for the conversion: an all-ones FP16 tensor of shape (1, 3, 224, 224) on the GPU.\n",
    "data = torch.ones((1, 3, 224, 224), dtype=torch.float16, device='cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the wrapped model with torch2trt, using `data` as the example input and FP16 mode on.\n",
    "model_trt = torch2trt.torch2trt(\n",
    "    model_w,\n",
    "    [data],\n",
    "    fp16_mode=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Live demo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from jetcam.csi_camera import CSICamera\n",
    "from jetcam.usb_camera import USBCamera\n",
    "\n",
    "# For a CSI camera use instead: camera = CSICamera(width=224, height=224)\n",
    "camera = USBCamera(height=224, width=224)\n",
    "\n",
    "# Switch on jetcam's continuous capture so camera.value keeps updating.\n",
    "camera.running = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jetcam.utils import bgr8_to_jpeg\n",
    "import traitlets\n",
    "import ipywidgets\n",
    "\n",
    "# Widget that shows the raw camera feed.\n",
    "image_w = ipywidgets.Image()\n",
    "\n",
    "# One-way link: every new camera.value is passed through bgr8_to_jpeg and\n",
    "# assigned to the widget's value.\n",
    "traitlets.dlink(\n",
    "    (camera, 'value'),\n",
    "    (image_w, 'value'),\n",
    "    transform=bgr8_to_jpeg,\n",
    ")\n",
    "\n",
    "display(image_w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "import numpy as np\n",
    "import torchvision\n",
    "\n",
    "device = torch.device('cuda')\n",
    "\n",
    "# Per-channel mean/std (the standard ImageNet values) scaled to the 0-255 pixel\n",
    "# range, because preprocess() does not divide the frame by 255.\n",
    "mean = np.array([0.485, 0.456, 0.406]) * 255.0\n",
    "stdev = np.array([0.229, 0.224, 0.225]) * 255.0\n",
    "\n",
    "normalize = torchvision.transforms.Normalize(mean=mean, std=stdev)\n",
    "\n",
    "def preprocess(camera_value):\n",
    "    \"\"\"Convert a camera frame into a normalized, batched tensor on `device`.\n",
    "\n",
    "    Args:\n",
    "        camera_value: image array in BGR channel order, as accepted by\n",
    "            cv2.cvtColor (presumably HxWx3 uint8 from the camera -- confirm).\n",
    "\n",
    "    Returns:\n",
    "        Float32 tensor with a leading batch axis, channels first, normalized\n",
    "        with the module-level `normalize` and moved to `device`.\n",
    "    \"\"\"\n",
    "    rgb = cv2.cvtColor(camera_value, cv2.COLOR_BGR2RGB)\n",
    "    chw = rgb.transpose((2, 0, 1))  # channels-last -> channels-first\n",
    "    tensor = normalize(torch.from_numpy(chw).float())\n",
    "    return tensor.to(device)[None, ...]  # prepend the batch axis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Second image widget; execute() writes the masked frame into it.\n",
    "seg_image = ipywidgets.Image()\n",
    "display(seg_image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def execute(change):\n",
    "    \"\"\"Segment the frame in change['new'] and display only its class-15 pixels.\n",
    "\n",
    "    Runs `model_trt` on the preprocessed frame, keeps the pixels whose arg-max\n",
    "    class is 15, and writes the masked frame (JPEG-encoded) to `seg_image`.\n",
    "\n",
    "    Args:\n",
    "        change: traitlets-style change dict; only change['new'] (the frame) is read.\n",
    "\n",
    "    Returns:\n",
    "        Boolean array that is True where the predicted class is 15.\n",
    "    \"\"\"\n",
    "    image = change['new']\n",
    "    # Segment the same frame that is masked below; re-reading camera.value here\n",
    "    # could pick up a newer frame than `image` while the camera is running.\n",
    "    output = model_trt(preprocess(image).half())[0].detach().cpu().float().numpy()\n",
    "    # NOTE(review): 15 is presumably the 'person' class of the Pascal VOC label\n",
    "    # set used by torchvision's pretrained segmentation models -- confirm.\n",
    "    mask = output.argmax(0) == 15\n",
    "    # A boolean mask keeps the product in the frame's own dtype (no float64 promotion).\n",
    "    seg_image.value = bgr8_to_jpeg(mask[:, :, None] * image)\n",
    "    return mask\n",
    "    \n",
    "    \n",
    "# One-off run on the current frame; the next cell attaches execute() to the\n",
    "# camera so that it runs on every new frame.\n",
    "mask = execute({'new': camera.value})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start live segmentation: call execute() whenever camera.value changes.\n",
    "camera.observe(handler=execute, names='value')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stop live segmentation: detach execute() from camera.value changes.\n",
    "camera.unobserve(handler=execute, names='value')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "NUM_ITERATIONS = 100\n",
    "\n",
    "# Drain pending GPU work so it is not attributed to the timed loop.\n",
    "torch.cuda.current_stream().synchronize()\n",
    "t0 = time.perf_counter()  # monotonic clock, unaffected by system time changes\n",
    "# Inference only: no_grad() avoids building an autograd graph on every forward\n",
    "# pass, which would otherwise inflate both the timing and the memory use.\n",
    "with torch.no_grad():\n",
    "    for i in range(NUM_ITERATIONS):\n",
    "        output = model_w(preprocess(camera.value).half())\n",
    "# Wait for the queued GPU work to finish before stopping the clock.\n",
    "torch.cuda.current_stream().synchronize()\n",
    "t1 = time.perf_counter()\n",
    "\n",
    "# Iterations per second for model_w, including the camera read and preprocess().\n",
    "print(NUM_ITERATIONS / (t1 - t0))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
